from absl import logging as absl_logging
absl_logging.set_verbosity(absl_logging.ERROR)
import gym
import gymnasium 
import os
import numpy as np
import warnings


# Observation wrapper that appends the previous action and reward to each observation (work in progress).
class LastActionRewardWrapper(gymnasium.Wrapper):
    """Augment a dict-observation env with the previous action and reward.

    The wrapped env is expected to expose a ``Dict`` observation space with an
    ``'obs'`` entry and a discrete action space (one with an ``.n``). Two keys
    are written into every observation dict returned by ``reset``/``step``:

    * ``'last_act'``: int32 one-hot vector of length ``action_space.n`` marking
      the action passed to the most recent ``step`` (all zeros after ``reset``).
    * ``'last_rew'``: float32 array of shape ``(1,)`` holding the reward from
      the most recent ``step`` (zero after ``reset``).
    """

    def __init__(self, env):
        super().__init__(env)
        n_actions = env.action_space.n
        self.observation_space = gymnasium.spaces.Dict(
            obs=env.observation_space['obs'],
            last_act=gymnasium.spaces.Box(low=0, high=1, shape=(n_actions,), dtype=np.int32),
            last_rew=gymnasium.spaces.Box(low=-float('inf'), high=float('inf'), shape=(1,), dtype=np.float32),
        )
        # One-hot of the last action taken; None until the first step().
        self.prev_action = None
        # Last reward as a (1,) float32 array; None until the first step().
        self.prev_reward = None

    def reset(self, seed=None, options=None):
        """Reset the wrapped env and add zeroed ``last_act``/``last_rew`` entries.

        Note: the observation dict returned by the wrapped env is modified in place.
        """
        obs, info = self.env.reset(seed=seed, options=options)
        self.prev_action = None
        self.prev_reward = None
        obs['last_act'] = np.zeros((self.action_space.n,), dtype=np.int32)
        obs['last_rew'] = np.zeros((1,), dtype=np.float32)
        return obs, info

    def step(self, action):
        """Step the wrapped env and record ``action``/``reward`` in the observation.

        ``action`` is used as an index into the one-hot vector, so it must be an
        integer in ``[0, action_space.n)``.
        """
        obs, reward, term, trunc, info = self.env.step(action)
        self.prev_action = np.zeros((self.action_space.n,), dtype=np.int32)
        self.prev_action[action] = 1
        # Force float32 so the value matches the declared ``last_rew`` space;
        # np.array([reward]) alone yields float64 for a Python float reward.
        self.prev_reward = np.array([reward], dtype=np.float32)
        obs['last_act'] = self.prev_action
        obs['last_rew'] = self.prev_reward
        return obs, reward, term, trunc, info
    

class MemoryMaze(gymnasium.Env):
    """Gymnasium adapter around a Memory Maze env created with the legacy ``gym`` API.

    Translates the legacy obs-only ``reset`` and 4-tuple ``step`` into the
    Gymnasium ``(obs, info)`` / 5-tuple API, and wraps each observation in a
    dict under the ``'obs'`` key, cast to ``np.uint8``.

    Args:
        env_name: id passed to ``gym.make`` (e.g. ``'memory_maze:MemoryMaze-9x9-v0'``).
        render_mode: stored on the instance; ``render`` returns the latest raw
            observation regardless of its value.
    """

    def __init__(self, env_name, render_mode=None):
        self.env_name = env_name
        self._env = gym.make(env_name)
        self.render_mode = render_mode
        src_space = self._env.observation_space
        self.observation_space = gymnasium.spaces.Dict(
            obs=gymnasium.spaces.Box(src_space.low, src_space.high, src_space.shape, dtype=np.uint8))
        self.action_space = gymnasium.spaces.Discrete(self._env.action_space.n)
        # Info dict from the most recent terminal step; returned by reset().
        self.last_info = {}
        # Most recent raw observation from the legacy env; returned by render().
        self.current_obs = None

    def reset(self, seed=None, options=None):
        """Reset the env and return ``({'obs': uint8 array}, info)``.

        ``options`` is ignored. When ``seed`` is given, the legacy env is rebuilt
        via ``gym.make(env_name, seed=seed)`` (the legacy ``env.seed`` method is
        not used). The previous instance is closed first so its resources are
        released. The returned ``info`` is the info dict of the last terminal
        step (``{}`` before any episode has ended). It is kept that way for
        AutoResetWrapper compatibility.
        """
        if seed is not None:  # WIP: will include seeding after the repo is updated
            if self._env is not None:
                self._env.close()
            self._env = gym.make(self.env_name, seed=seed)
        with warnings.catch_warnings():
            # Silence legacy-gym warnings emitted during reset.
            warnings.simplefilter("ignore")
            self.current_obs = self._env.reset()
            return {'obs': np.array(self.current_obs, dtype=np.uint8)}, self.last_info

    def render(self, mode=None):
        """Return the most recent raw observation (``mode`` is ignored)."""
        return self.current_obs

    def step(self, action):
        """Step the legacy env and convert the result to the Gymnasium 5-tuple.

        The legacy API exposes a single ``done`` flag, so it is reported as both
        ``terminated`` and ``truncated``.
        """
        obs, reward, term, info = self._env.step(action)
        self.current_obs = obs
        if term:  # If the episode is over, store info to make it compatible with AutoResetWrapper.
            self.last_info = info
        return {'obs': np.array(obs, dtype=np.uint8)}, reward, term, term, info

    def close(self):
        """Close the wrapped legacy env. Safe to call more than once."""
        if self._env is not None:
            self._env.close()
            self._env = None


def create_memory_maze_env(**env_config):
    """Construct a Memory Maze env whose observations include the last action and reward.

    Keyword Args:
        name: gym id forwarded to ``MemoryMaze`` as ``env_name``. Required;
            a missing key raises ``KeyError``. Any other keys are ignored.

    Returns:
        A ``LastActionRewardWrapper`` around a freshly built ``MemoryMaze``.
    """
    maze_id = env_config['name']
    base_env = MemoryMaze(env_name=maze_id)
    return LastActionRewardWrapper(base_env)


if __name__ == '__main__':
    # Smoke test: build the Gymnasium adapter, run one reset, then release the env.
    # (The previous standalone gym.make() call was immediately discarded and never closed.)
    env = MemoryMaze('memory_maze:MemoryMaze-9x9-v0')
    obs, info = env.reset()
    env.close()

        